Skip to content

Conversation

ligerlac
Copy link

@ligerlac ligerlac commented May 21, 2025

This fixes a bug in the jax_backend.tile() method. Consider the following minimal example:

import jax.numpy as jnp
import pyhf

pyhf.set_backend("jax", default=True)  # works without this line

spec = {
    "channels": [
        {
            "name": "singlechannel",
            "samples": [
                {
                    "name": "signal",
                    "data": jnp.array([0.0, 0.0, 0.0]),
                    "modifiers": [
                        {
                            "name": "mu",
                            "type": "normfactor",
                            "data": None,
                        },
                    ],
                },
            ],
        },
    ],
}

my_model = pyhf.Model(spec, validate=False)

The last line fails with TypeError: tile requires ndarray or scalar arguments, got <class 'list'> at position 0.. However, it works fine when using the numpy backend. The problem stems from differences between np.tile and jnp.tile:

import numpy as np
import jax.numpy as jnp

tensor_in = [[[0, 1, 2]]]
repeats = (0, 1, 1)

np.tile(tensor_in, repeats)  # works fine
jnp.tile(tensor_in, repeats)  # fails with same error message as above
jnp.tile(jnp.array(tensor_in), repeats)  # works fine

Unlike jnp.tile, np.tile implicitly converts the input to the correct type.
This PR ensures tensor_in is a jnp.array to make the behaviour of numpy_backend.tile() and jax_backend.tile() consistent.

@matthewfeickert matthewfeickert changed the title Bugfix: make numpy_backend.tile() and jax_backend.tile() consistent fix: make numpy_backend.tile() and jax_backend.tile() consistent May 22, 2025
@matthewfeickert matthewfeickert added the fix A bug fix label May 22, 2025
@matthewfeickert
Copy link
Member

@ligerlac Thanks for the PR. Today I have been clawing myself out of travel related time dependent TODOs, but I can review this on Thursday (2025-05-22).

I haven't looked/thought about this yet, but I assume that this isn't something unique to tile but more generic to how things are being dealt with in spec validation of pyhf.Model (though maybe if I actually think about the PR the reason would be clear to me). Is this a general solution or more of a targeted use patch?

@matthewfeickert matthewfeickert requested review from a team, kratsg and matthewfeickert and removed request for a team May 22, 2025 06:29
@ligerlac
Copy link
Author

It's more of a patch. You are right, the problem is not unqiue to tile(). There are similar problems with concatenate():

pyhf.set_backend("jax", default=True)  # works without this line

spec = {
    "channels": [
        {
            "name": "singlechannel",
            "samples": [
                {
                    "name": "signal",
                    "data": jnp.array([0.0, 0.0, 0.0]),
                    "modifiers": [
                        {
                            "name": "mu",
                            "type": "normfactor",
                            "data": None,
                        }, 
                    ],
                },
                {
                    "name": "background",
                    "data": jnp.array([0.0, 0.0, 0.0]),  # dummy data
                    "modifiers": [
                        {
                            "name": "correlated_bkg_uncertainty",
                            "type": "histosys",
                            "data": {
                                "hi_data": jnp.array([0.0, 0.0, 0.0]),
                                "lo_data": jnp.array([0.0, 0.0, 0.0]),
                            },
                        },
                    ],
                },
            ],
        },
    ],
}

my_model = pyhf.Model(spec, validate=False)

last line fails with TypeError: concatenate requires ndarray or scalar arguments, got <class 'list'> at position 0.. Again, this can be tracked down to a difference between np.concatenate() and jnp.concatenate():

np.concatenate([[True, True, True]])  # works
jnp.concatenate([[True, True, True]])  # fails
jnp.concatenate(jnp.array([[True, True, True]]))  # works

We could also patch that in the jax backend. But I guess a more elegant solution would be to make sure that each backend is only receiving arguments of the correct type by calling tensorlib.astensor in all the right places (like the _precompute() methods). I'll try to find some time over the weekend to have another look at this (including missings tests).

Copy link

codecov bot commented Oct 14, 2025

Codecov Report

❌ Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 98.18%. Comparing base (10488f0) to head (aab16bd).

Files with missing lines Patch % Lines
src/pyhf/tensor/jax_backend.py 0.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2587      +/-   ##
==========================================
- Coverage   98.23%   98.18%   -0.05%     
==========================================
  Files          65       65              
  Lines        4193     4195       +2     
  Branches      591      592       +1     
==========================================
  Hits         4119     4119              
- Misses         45       46       +1     
- Partials       29       30       +1     
Flag Coverage Δ
contrib 97.92% <0.00%> (-0.05%) ⬇️
doctest 98.04% <0.00%> (-0.05%) ⬇️
unittests-3.10 96.23% <0.00%> (-0.05%) ⬇️
unittests-3.11 96.23% <0.00%> (-0.05%) ⬇️
unittests-3.12 96.23% <0.00%> (-0.05%) ⬇️
unittests-3.8 96.23% <0.00%> (-0.05%) ⬇️
unittests-3.9 96.28% <0.00%> (-0.05%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

fix A bug fix

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

2 participants